from airhockey import AirHockeyEnv
from airhockey.renderers import AirHockeyRenderer
import yaml
import gymnasium as gym
import numpy as np
from Environment.environment import strip_instance, Reward, Done, non_state_factors
import copy
from Environment.Environments.AirHockey.air_hockey_specs import air_hockey_variants

class AirHockeyObject():
    '''
    Factored-state record for one simulator object (paddle, puck, block, ...).

    name: the factor name used by the wrapper (e.g. "Puck0")
    internal_name: key of the object's group inside a simulator state dict (see map_name)
    internal_index: index/key of this instance inside that group
    internal_keys: names of the sub-fields making up the state (stored, not interpreted here)
    initial_state: the value returned by get_state() until the state is first updated
    '''
    def __init__(self, name, internal_name, internal_index, internal_keys, initial_state):
        self.name = name
        self.internal_name = internal_name
        self.internal_state = initial_state
        self.internal_index = internal_index
        self.internal_keys = internal_keys
        self.interaction_trace = list()

    @property
    def attribute(self):
        # `attribute` is the field the wrapper writes on every step (matching the
        # Reward/Done/Goal factors); alias it to internal_state so get_state()
        # reflects those writes instead of returning the stale initial state.
        return self.internal_state

    @attribute.setter
    def attribute(self, value):
        self.internal_state = value

    def set_state(self, state_dict):
        '''
        pulls this object's entry out of a simulator state dict, indexed first by
        internal_name and then by internal_index
        '''
        self.internal_state = state_dict[self.internal_name][self.internal_index]

    def get_state(self):
        '''
        returns the most recently stored state (via set_state or the attribute setter)
        '''
        return self.internal_state

def map_name(name):
    '''
    translates a factor name (instance suffix removed by strip_instance) into the
    plural group key used for that object class; names without a known group are
    returned in their stripped form
    '''
    group_keys = {
        'Puck': "pucks",
        'Paddle': "paddles",
        'RewardRegion': "reward_regions",
        'Obstacle': "obstacles",
        'Target': "targets",
        'Block': "blocks",
    }
    base = strip_instance(name)
    if base in group_keys:
        return group_keys[base]
    return base
    
class AirHockeyGoal():
    '''
    Goal factor for the air hockey wrapper.

    Holds the desired goal (self.attribute), the bounds goals are sampled from,
    and the logic to extract the achieved goal from the factored object states.

    goal_type selects which leading entries of an object's state are compared:
        "position": the first 2 entries
        "velocity": the first 4 entries
        "all": the whole state vector
    target_object is the name (as it appears in env.all_names) of the single object
    the goal refers to; an empty string means the goal spans every state object.
    '''
    # slice of a per-object state vector used by each goal type
    _GOAL_SLICES = {"position": slice(0, 2), "velocity": slice(0, 4), "all": slice(None)}

    def __init__(self, env, goal_type, target_object):
        self.name = "Goal" # lets the goal be treated like the other named factors
        self.goal_type = goal_type
        self.all_names = env.all_names
        self.all_state_names = [n for n in env.all_names if n not in non_state_factors]
        self.all_state_indices = [i for i, n in enumerate(self.all_names) if n not in non_state_factors]
        self.target_object = target_object
        self.target_idx = env.all_names.index(target_object) if len(target_object) > 0 else -1
        # no goal can be sampled before set_bounds is called, so start unset
        # (previously this stored the bound method sample_goal itself)
        self.attribute = None
        self.goal_epsilon = env.env.max_goal_radius
        self.interaction_trace = list() # appended to by add_interaction

    def set_bounds(self, object_lims, override_lims=None):
        '''
        sets self.bounds as a [low, high] pair
        object_lims: dict of object class name -> [low, high] state limits
        override_lims: if given, used verbatim as the bounds
        an unrecognized goal_type leaves the bounds untouched
        '''
        # identity check: `!= None` is elementwise (and then ambiguous) on numpy arrays
        if override_lims is not None:
            self.bounds = override_lims
            return
        goal_slice = self._GOAL_SLICES.get(self.goal_type)
        if goal_slice is None:
            return
        if self.target_idx >= 0:
            # limits are usually keyed by class name ("Puck") while the target is an
            # instance name ("Puck0"); accept either form
            key = self.target_object if self.target_object in object_lims else strip_instance(self.target_object)
            lims = object_lims[key]
            # copies, so later edits of the bounds do not leak into object_lims
            self.bounds = [np.array(lims[0][goal_slice]), np.array(lims[1][goal_slice])]
        else:
            # use the same objects get_achieved_goal_state selects (all_state_indices)
            # NOTE(review): this assumes non_state_factors excludes every name without
            # an entry in object_lims (e.g. "Goal") -- confirm
            keys = [strip_instance(o) for o in self.all_state_names]
            self.bounds = [np.concatenate([object_lims[k][0][goal_slice] for k in keys]),
                           np.concatenate([object_lims[k][1][goal_slice] for k in keys])]

    def sample_goal(self):
        '''
        returns a goal drawn uniformly between the low and high bounds, with the
        same shape as the low bound (requires set_bounds to have been called)
        '''
        low, high = (np.asarray(b, dtype=float) for b in self.bounds)
        return low + np.random.rand(*low.shape) * (high - low)

    def get_achieved_goal(self, env):
        '''
        stacks the current state of every object in all_names (zero padded on the
        right to the longest state) and extracts the achieved goal from it
        '''
        states = [np.asarray(env.object_name_dict[n].get_state()) for n in self.all_names]
        longest = max(len(s) for s in states)
        state = np.stack([np.pad(s, (0, longest - s.shape[0])) for s in states], axis=0)
        return self.get_achieved_goal_state(state)

    def get_achieved_goal_state(self, object_state, fidx=None):
        '''
        object_state: array of shape [..., num objects, state dim]
        returns the goal-relevant entries of the target object, or of every state
        object when there is no single target; None for an unrecognized goal_type
        '''
        # TODO: implement so that fidx is used
        goal_slice = self._GOAL_SLICES.get(self.goal_type)
        if goal_slice is None:
            return None
        rows = self.target_idx if self.target_idx >= 0 else self.all_state_indices
        return object_state[..., rows, goal_slice]

    def add_interaction(self, reached_goal):
        '''
        when the goal was reached, records which objects it involved in interaction_trace
        '''
        if reached_goal:
            if self.target_idx >= 0:
                self.interaction_trace += [self.all_names[self.target_idx]]
            else:
                self.interaction_trace += [n for n in self.all_names if n not in non_state_factors]

    def get_state(self):
        '''
        returns the current desired goal (None until one has been assigned)
        '''
        return self.attribute

    def check_goal(self, env):
        '''
        True when every squared per-dimension error between achieved and desired
        goal is below goal_epsilon
        '''
        return np.all(np.square(self.get_achieved_goal(env) - self.attribute) < self.goal_epsilon)

class AirHockeyWrap(gym.Env):
    '''
    Gymnasium wrapper around AirHockeyEnv that exposes, next to the raw
    observation (or rendered frame), a factored state: one entry per named object
    (Action, paddles, pucks, blocks, obstacles, targets, reward regions, optional
    Goal, Reward, Done).
    '''
    def __init__(self, frameskip = 1, horizon=300, variant="", fixed_limits=False):
        ''' required attributes:
            num actions: int or None
            action_space: gym.Spaces
            action_shape: tuple of ints
            observation_space = gym.Spaces
            done: boolean
            reward: int
            seed_counter: int
            discrete_actions: boolean
            name: string
        frameskip: number of simulator steps taken per call to step
        horizon: passed to the simulator as max_timesteps
        variant: key into air_hockey_variants selecting config path, goal settings, max reward and raw image use
        fixed_limits: stored on the wrapper as self.fixed_limits
        '''
        self.variant = variant
        (self.env_config_path, self.goal_type, self.target_object,
         self.goal_lims, self.max_reward, self.raw_images) = air_hockey_variants[variant]
        with open(self.env_config_path, 'r') as f:
            air_hockey_cfg = yaml.safe_load(f)

        air_hockey_params = air_hockey_cfg['air_hockey']
        air_hockey_params['n_training_steps'] = air_hockey_cfg['n_training_steps']
        # goal observations are only requested for SAC on a goal task
        air_hockey_params['return_goal_obs'] = (air_hockey_cfg['algorithm'] == 'sac'
                                                and 'goal' in air_hockey_params['task'])

        air_hockey_params_cp = air_hockey_params.copy()
        air_hockey_params_cp['seed'] = 42
        air_hockey_params_cp['max_timesteps'] = horizon

        self.env = AirHockeyEnv(air_hockey_params_cp)
        self.renderer = AirHockeyRenderer(self.env)
        self.obs, _ = self.env.reset()

        # environment properties
        self.self_reset = True
        NUM_DISC_ACTIONS = 4
        CONTINUOUS_DIM = 2

        self.num_actions = -1 if air_hockey_params["action_space"] != 'discrete' else NUM_DISC_ACTIONS # assumes discrete is wasd
        self.name = "AirHockey"
        self.fixed_limits = fixed_limits # was hard-coded to False, ignoring the argument
        self.discrete_actions = air_hockey_params["action_space"] == 'discrete'
        self.frameskip = frameskip
        self.transpose = True # transposes the visual domain

        # TODO: we could possible add joint space/end effector space control
        self.action_shape = (1,) if self.discrete_actions else (CONTINUOUS_DIM,)
        self.action_space = gym.spaces.Discrete(self.num_actions) if self.discrete_actions else gym.spaces.Box(low=np.array([-1] * CONTINUOUS_DIM), high=np.array([1] * CONTINUOUS_DIM))
        self.observation_space = self.env.observation_space # raw space, gym.spaces
        self.pos_size = 2 # the size of the position vector, if used for that object

        # state components
        self.frame = None # the image generated by the environment
        self.reward = Reward() # TODO: make reward from the environment, not standalone
        self.done = Done() # TODO: make done specialized to store the information from the environment
        # NOTE(review): Action was referenced but never imported in this file; it is
        # presumed to live next to Reward/Done -- confirm the module and signature
        from Environment.environment import Action
        self.action = Action(self.discrete_actions, self.action_shape)
        self.extracted_state = None

        # running values
        self.itr = 0

        # factorized state properties
        paddles = ['Paddle' + str(i) for i in range(self.env.num_paddles)]
        pucks = ['Puck' + str(i) for i in range(self.env.num_pucks)]
        blocks = ['Block' + str(i) for i in range(self.env.num_blocks)]
        obstacles = ['Obstacle' + str(i) for i in range(self.env.num_obstacles)]
        targets = ['Target' + str(i) for i in range(self.env.num_targets)]
        rewardRegions = ['RewardRegion' + str(i) for i in range(len(self.env.reward_regions))]
        self.all_names = ['Action'] + paddles + pucks + blocks + obstacles \
              + targets + rewardRegions + (["Goal"] if self.env.goal_conditioned else []) + ["Reward", "Done"]

        # the goal reads self.all_names, so it can only be built once the names exist
        self.goal = AirHockeyGoal(self, self.goal_type, self.target_object) if self.env.goal_conditioned else None # TODO: make sure self.goal sets the goal in the lower level env

        _, info = self.env.reset()
        initial_state = info['state_info']
        self.valid_names = copy.deepcopy(self.all_names)
        # TODO: storing all of the names, though in reality we might have to filter if there aren't any of that type in the class
        self.object_names = ["Paddle", "Puck", "Block", "Obstacle", "Target", "RewardRegion", "Action", "Goal", "Reward", "Done"]
        self.num_objects = len(self.all_names)
        self.object_name_dict = {**{name: AirHockeyObject(name, map_name(name), self._instance_index(name), ['position', 'velocity'], initial_state)
                                    for name in self.all_names if name not in ["Action", "Reward", "Goal", "Done"]},
                                **{"Action": self.action, "Reward": self.reward, "Done": self.done}}
        if self.goal is not None: self.object_name_dict["Goal"] = self.goal

        # ranges first: the goal dimensionality below is derived from the goal bounds
        self.object_range, self.object_dynamics = self.init_object_ranges() # TODO: implement
        # was unpacking a deepcopy of object_range alone into both names
        self.object_range_true = copy.deepcopy(self.object_range)
        self.object_dynamics_true = copy.deepcopy(self.object_dynamics)
        goal_dim = self._goal_dim()

        self.object_sizes = {n: 2 for n in self.object_names}
        self.object_sizes["Action"] = self.action_shape[0]
        self.object_sizes["Reward"] = 1
        self.object_sizes["Done"] = 1
        if self.goal is not None: self.object_sizes["Goal"] = goal_dim

        self.object_instanced = {"Paddle": len(paddles), "Puck": len(pucks), "Block": len(blocks),
                                 "Obstacle": len(obstacles), "Target": len(targets), "RewardRegion": len(rewardRegions),
                                 "Action": 1, "Reward": 1, "Done": 1}
        if self.goal is not None: self.object_instanced["Goal"] = 1
        self.object_proximal = {n: True for n in self.object_names if n not in ["Action", "Reward", "Done"]} # all objects support proximity, to varying degrees
        self.instance_length = len(self.all_names)

        # proximity state components
        self.position_masks = dict()
        for n in self.object_names:
            if n == "Goal":
                self.position_masks[n] = [1,1] + ([0] * goal_dim)
            elif n in ["Action", "Reward", "Done"]:
                self.position_masks[n] = np.array([0])
            else:
                self.position_masks[n] = np.array([1,1,0,0])
        self.goal_based = self.env.goal_conditioned
        # TODO: define the goal object
        # NOTE(review): AirHockeyGoal as defined in this file provides neither
        # obs_space nor goal_idx, so both are None until the goal object defines them
        self.goal_space = getattr(self.goal, "obs_space", None)
        self.goal_idx = getattr(self.goal, "goal_idx", None)

    @staticmethod
    def _instance_index(name):
        # integer instance suffix of a factor name ("Puck3" -> 3), 0 when there is none
        suffix = name[len(strip_instance(name)):]
        return int(suffix) if len(suffix) > 0 else 0

    def _goal_dim(self):
        # length of the goal vector: goal.shape[0] when the goal object exposes a
        # shape, otherwise the length of its lower bound; 0 without a goal
        if self.goal is None:
            return 0
        shape = getattr(self.goal, "shape", None)
        if shape is not None:
            return shape[0]
        return int(np.asarray(self.goal.bounds[0]).shape[0])

    def init_object_ranges(self, goal_lims = None):
        '''
        builds the per-class state limits and the per-class dynamics limits
        goal_lims: optional override of the goal bounds, defaults to self.goal_lims
        returns (lims, dynamic_lims), both dicts keyed by object class name
        '''
        def table_lims():
            # fresh arrays per class: full table extent with puck velocity limits
            return [np.array([self.env.table_x_top, self.env.table_y_left, -self.env.max_puck_vel, -self.env.max_puck_vel]),
                    np.array([self.env.table_x_bot, self.env.table_y_right, self.env.max_puck_vel, self.env.max_puck_vel])]

        # NOTE(review): the paddle velocity lower bound is 0.0 rather than -max_paddle_vel -- confirm
        lims = {"Paddle": [np.array([0.0, self.env.table_y_left, 0.0, 0.0]), np.array([self.env.table_x_bot, self.env.table_y_right, self.env.max_paddle_vel, self.env.max_paddle_vel])],
                "Puck": table_lims(),
                "Block": table_lims(),
                "Obstacle": table_lims(),
                "Target": table_lims(),
                "RewardRegion": table_lims(),
                "Action": [np.array([0]), np.array([self.num_actions])] if self.discrete_actions else [np.ones(self.action_shape) * -1, np.ones(self.action_shape)],
                "Reward": [np.array([0]), np.array([self.max_reward])],
                "Done": [np.array([0]), np.array([1])]}

        # NOTE(review): object entries here are a single concatenated array while the
        # Action/Reward/Done/Goal entries are [low, high] pairs -- confirm intended
        dynamic_lims = {n: np.concatenate([l[1][2:] - l[0][2:], 3 *(l[1][2:] - l[0][2:])]) for n, l in lims.items() if n not in ["Action", "Reward", "Done", "Goal"]}
        dynamic_lims["Action"] = [- np.array([self.num_actions]), np.array([self.num_actions])] if self.discrete_actions else [np.ones(self.action_shape) * -2, np.ones(self.action_shape) * 2]
        dynamic_lims["Reward"] = [-np.array([self.max_reward]), np.array([self.max_reward])]
        dynamic_lims["Done"] = [-np.array([1]), np.array([1])]

        # goal entries only exist for goal conditioned environments
        if self.goal is not None:
            self.goal.set_bounds(lims, override_lims = goal_lims if goal_lims is not None else self.goal_lims)
            lims["Goal"] = self.goal.bounds
            goal_low, goal_high = np.asarray(self.goal.bounds[0]), np.asarray(self.goal.bounds[1])
            dynamic_lims["Goal"] = [np.ones(goal_low.shape) * -0.01, np.ones(goal_high.shape) * 0.01] # goal shouldn't change within a trajectory

        return lims, dynamic_lims

    def clear_interactions(self):
        '''
        empties the interaction trace of every factor that keeps one
        '''
        for obj in self.object_name_dict.values():
            if hasattr(obj, "interaction_trace"):
                obj.interaction_trace = list()

    def step(self, action):
        '''
        matches the API of gymnasium by taking in action, repeated for frameskip simulator steps
        (stopping early if the episode terminates or is truncated)
        returns
            state as dict: next raw_state (image or observation) next factored_state (dictionary of object name to object state)
            reward: the reward summed over the skipped frames
            done flag: if an episode ends, done is True
            truncated flag: as returned by the simulator
            info: the info dict of the last simulator step
        '''
        total_reward = 0.0
        is_finished, truncated, info = False, False, dict() # defined even when frameskip < 1
        self.clear_interactions()
        # NOTE(review): the Action factor is not updated here (set_states is called without action) -- confirm
        for _ in range(self.frameskip):
            self.obs, reward, is_finished, truncated, info = self.env.step(action)
            self.set_states(info)
            total_reward += reward
            if self.goal is not None:
                reached = self.goal.check_goal(self)
                self.goal.add_interaction(reached)
            if is_finished or truncated:
                break
        self.done.attribute = np.array([is_finished])
        self.reward.attribute = np.array([total_reward])
        state = self.get_state()
        # return the accumulated reward, matching what is stored in the Reward factor
        return state, total_reward, is_finished, truncated, info

    def _store_object_state(self, name, entry):
        # state of an object is its position followed by its velocity; entries
        # without a velocity are treated as static (zero velocity)
        position = np.asarray(entry['position'])
        velocity = np.asarray(entry['velocity']) if 'velocity' in entry else np.zeros_like(position)
        value = np.concatenate([position, velocity])
        obj = self.object_name_dict[name]
        # written to both fields so get_state() (internal_state) and readers of
        # .attribute see the same value
        obj.attribute = value
        obj.internal_state = value

    def set_states(self, info, action=None, reward=None, done=None):
        '''
        copies the simulator state in info into the factored objects
        action, reward, done: when given, also update the corresponding factors
        '''
        # the per-object groups are read from info directly; fall back to the nested
        # 'state_info' dict (the form used for the initial state) when they are absent
        state_info = info
        if "paddles" not in info and "state_info" in info:
            state_info = info["state_info"]
        for n in self.all_names:
            kind = strip_instance(n)
            if n == "Action":
                if action is not None:
                    self.action.set_action(action)
            elif n == "Reward":
                if reward is not None:
                    self.reward.attribute = reward
            elif n == "Done":
                if done is not None:
                    self.done.attribute = done
            elif kind == "Paddle":
                # TODO: we can only handle one paddle because the paddle ego notation is problamatic
                self._store_object_state(n, state_info['paddles']['paddle_ego'])
            elif kind in ["Puck", "Block", "Obstacle", "Target", "RewardRegion"]:
                # each object reads position and velocity from its own group entry
                # (the velocity was previously taken from info['paddles'])
                self._store_object_state(n, state_info[map_name(n)][self._instance_index(n)])

    def reset(self, seed=-1):
        '''
        matches the API of gymnasium, resetting the environment
        seed: a negative value (the default) or None resets without reseeding
        returns:
            state as dict: next raw_state, next factored_state (dict with corresponding keys)
            info: the info dict returned by the simulator reset
        '''
        if seed is None or seed < 0:
            self.obs, info = self.env.reset()
        else:
            self.obs, info = self.env.reset(seed=seed)
        self.set_states(info)
        if self.goal is not None: self.goal.attribute = self.env.get_desired_goal()
        return self.get_state(), info


    def render(self, mode='human'):
        '''
        returns the current frame from the renderer (mode is not used)
        '''
        return self.renderer.get_frame()

    def close(self):
        '''
        closes and performs cleanup
        '''
        self.env.close()
        # self.renderer.close()

    def seed(self, seed):
        '''
        numpy should be the only source of randomness, but override if there are more
        '''
        # the parent only has a seed method in older gym versions; call it when present
        parent_seed = getattr(super(), "seed", None)
        if callable(parent_seed):
            parent_seed(seed)
        self.env.reset(seed=seed)


    def get_state(self):
        '''
        returns a dictionary with keys:
            raw_state: the rendered frame if raw_images is set, otherwise the last observation
            factored_state: dictionary of object name to that object's state
        '''
        raw_state = self.render() if self.raw_images else self.obs
        # keyed by the dictionary name, so factors without a .name attribute work too
        factored_state = {name: obj.get_state() for name, obj in self.object_name_dict.items()}
        return {"raw_state": raw_state, "factored_state": factored_state}

    def get_info(self): # returns the info, the most important value is TimeLimit.truncated, can be overriden
        return {"TimeLimit.truncated": False}

    def get_itr(self):
        return self.itr

    def run(self, policy, iterations = 10000):
        '''
        runs policy for up to iterations steps, stopping early when it returns the action -1
        '''
        full_state = self.get_state()
        for self.itr in range(iterations):
            action = policy.act(full_state)
            if np.size(action) == 1 and np.all(np.asarray(action) == -1): # signal to quit
                break
            # step returns (state, reward, done, truncated, info); the policy only takes the state
            full_state = self.step(action)[0]

    def set_from_factored_state(self, factored_state, valid_names):
        '''
        TODO: implement
        '''

    def demonstrate(self):
        ''' TODO: incomplete
        '''
        return 0

    def set_goal_params(self, goal_params):
        # sets parameters like the goal radius, set in subclass
        # TODO: incomplete
        return None
